#!/usr/bin/python

# FIXME use subparsers in argparse

import numpy,os,os.path,sys,time,fcntl,multiprocessing,argparse,codecs,traceback
import random as pyrandom
from collections import Counter,defaultdict
from pylab import *
import tables
from tables import openFile,Filters,Int32Atom,Float32Atom,Int64Atom
import ocrolib
from ocrolib.ligatures import lig
from ocrolib import h5utils,improc

def split_re(s):
    """Split a substitution spec such as /old/new/ into the list [old, new].

    The first character of `s` acts as the delimiter.  It must also be the
    last character, and the text between the two must contain the delimiter
    exactly once; otherwise an AssertionError is raised.
    """
    delimiter = s[0]
    assert delimiter==s[-1],"specify regex substitutions as /old/new/"
    fields = s[1:-1].split(delimiter)
    assert len(fields)==2,"specify regex substitutions as /old/new/ or !old!new! etc."
    return fields

def chunks(n,c):
    """Iterate over the range 0...n in chunks of size c.

    Yields (start, end) pairs with half-open bounds; the last chunk is
    shortened so that `end` never exceeds `n`.  Nothing is yielded when
    `n` is not positive.

    Raises ValueError (when iteration starts) if there is work to do but
    `c` is not positive, since the loop could never make progress.
    """
    if n>0 and c<=0:
        raise ValueError("chunk size must be positive, got %r"%(c,))
    i = 0
    while i<n:
        j = min(n,i+c)
        yield (i,j)
        i = j

if len(sys.argv)<2:
    print """usage:

ocropus-db cat --help
ocropus-db shuffle --help
ocrpous-db predict --help

Try "ocropus cat --help" etc. for more info.
"""
    sys.exit(0)

if sys.argv[1]=="cat":
    parser = argparse.ArgumentParser(description = "Concatenate data files.")
    parser.add_argument('inputs',nargs="+")
    parser.add_argument('-N','--nsamples',type=int,default=int(1e9),help="copy at most this many samples")
    parser.add_argument('-s','--select',default=None,help="select only classes matching the regular expression")
    parser.add_argument('-r','--regex',nargs="*",default=[],help="perform regular expression replacements on classes; /old/new/")
    parser.add_argument('-o','--output',help="output file")
    args = parser.parse_args(sys.argv[2:])
    select = None
    if args.select is not None:
        select = re.compile('^'+args.pattern+'$')
    with openFile(args.inputs[0]) as db:
        eshape = db.root.patches[0].shape
    print "element shape",eshape
    replacements = [split_re(r) for r in args.regex]
    with openFile(args.output,"w") as odb:
        h5utils.log(odb,str(sys.argv))
        h5utils.create_earray(odb,"patches",eshape,'f')
        h5utils.create_earray(odb,"classes",(),'f')
        sizemode = None
        for fname in args.inputs:
            with openFile(fname) as db:
                sizemode = db.getNodeAttr("/","sizemode")
                odb.setNodeAttr("/","sizemode",sizemode)
                h5utils.log_copy(db,odb)
                nsamples = len(db.root.classes)
                for start,end in chunks(nsamples,10000):
                    print "%9d %9d %s"%(len(odb.root.classes),start,fname)
                    if len(odb.root.classes)>args.nsamples: break
                    data = db.root.patches[start:end]
                    clss = db.root.classes[start:end]
                    if select is not None:
                        good = array([(cls>=32 and select.match(lig.chr(cls)) is not None) for cls in clss],bool)
                        assert len(good)==len(data)
                        data = data[good]
                        clss = clss[good]
                    if replacements!=[]:
                        for i,c in enumerate(clss):
                            for old,new in replacements:
                                clss[i] = lig.ord(re.sub(old,new,lig.chr(clss[i])))
                    n = min(end-start,args.nsamples-len(odb.root.patches))
                    if n==0: break
                    odb.root.patches.append(data[:n])
                    odb.root.classes.append(clss[:n])
            if len(odb.root.classes)>=args.nsamples: break
    sys.exit(0)

if sys.argv[1]=="shuffle":
    parser = argparse.ArgumentParser(description = "Concatenate data files.")
    parser.add_argument('input',help="input database")
    parser.add_argument('-N','--nsamples',type=int,default=int(1e9),help="copy at most this many samples")
    parser.add_argument('-o','--output',help="output database")
    args = parser.parse_args(sys.argv[2:])
    with openFile(args.input) as db:
        eshape = db.root.patches[0].shape
        sizemode = db.getNodeAttr("/","sizemode")
    print "element shape",eshape
    with openFile(args.output,"w") as odb:
        odb.setNodeAttr("/","sizemode",sizemode)
        h5utils.log(odb,str(sys.argv))
        h5utils.create_earray(odb,"patches",eshape,'f')
        h5utils.create_earray(odb,"classes",(),'f')
        with openFile(args.input) as db:
            h5utils.log_copy(db,odb)
            nsamples = min(len(db.root.classes),args.nsamples)
            samples = pyrandom.sample(xrange(len(db.root.classes)),nsamples)
            for pos,i in enumerate(samples):
                if pos%1000==0: print pos
                odb.root.patches.append([array(db.root.patches[i])])
                odb.root.classes.append([db.root.classes[i]])
    sys.exit(0)
    
if sys.argv[1]=="predict":
    parser = argparse.ArgumentParser(description = "Concatenate data files.")
    parser.add_argument('input',help="input database")
    parser.add_argument('-t','--testset',type=int,default=-1,help="use testset sequence t (-1=use all samples)")
    parser.add_argument('-m','--model',default=None,help="model to be evaluated")
    parser.add_argument('-N','--nsamples',type=int,default=int(1e9),help="copy at most this many samples")
    parser.add_argument('--show',action="store_true")
    args = parser.parse_args(sys.argv[2:])
    cmodel = ocrolib.ocropus_find_file(args.model)
    #print "loading",cmodel
    cmodel = ocrolib.load_component(cmodel)
    confusion = Counter()
    with openFile(args.input) as db:
        nsamples = min(len(db.root.classes),args.nsamples)
        samples = pyrandom.sample(xrange(len(db.root.classes)),nsamples)
        for i in samples:
            if args.testset>=0 and not ocrolib.testset(i,sequence=args.testset): continue
            patch = db.root.patches[i]
            outputs = cmodel.coutputs(patch)
            actual = lig.chr(db.root.classes[i])
            pred = outputs[0][0] if len(outputs)>0 else "~"
            if args.show:
                clf(); gray(); ion()
                subplot(211); imshow(patch)
                xlabel("%s"%actual)
                subplot(212); imshow(cmodel.splitter.center(patch)[1].reshape(32,32))
                xlabel("%s"%outputs[:2])
                ginput(1,99)
            confusion[(actual,pred)] += 1
    errs = sum([v for (a,p),v in confusion.items() if a!=p])
    total = sum(confusion.values())
    print "ERROR",errs,total,errs*100.0/total,args.model
    sys.exit(0)

def tess_readchars(fname,boxfile=None,d=1,pad=2):
    """Read characters in Tesseract training format."""
    # os.system("convert '%s' /tmp/image.png"%fname)
    image = ocrolib.read_image_gray(fname)
    image -= amin(image)
    image /= amax(image)
    h,w = image.shape
    if boxfile is None:
        base,_ = ocrolib.allsplitext(fname)
        boxfile = base+".box"
    with codecs.open(boxfile,"r","utf8") as stream:
        lineno = 0
        for lineno,line in enumerate(stream.readlines()):
            try:
                f = line.split()
                c = f[0]
		try:
		    x0,y0,x1,y1 = [int(x) for x in f[1:5]]
		except:
                    print boxfile,":",lineno,": syntax error"
                    print "   ",line
                    continue
                if y1-y0>200 or x1-x0>200:
                    print boxfile,":",lineno,": bad box dimensions"
                    print "   ",line
                    continue
                cimage = image[h-y1-1-d:h-y0-1+d,x0-d:x1+1].copy()
                # if "georgiai" in fname and lineno==1682: raise Exception("debug")
                cimage /= amax(cimage)
                cimage = improc.pad_by(1-cimage,pad)
                yield c,cimage
                lineno += 1
            except:
                traceback.print_exc()
                print fname,lineno,line

if sys.argv[1]=="tess2h5":
    parser = argparse.ArgumentParser(description = "Convert Tesseract box files to OCRopus databases.")
    parser.add_argument('inputs',nargs='+',help="input database")
    parser.add_argument('-d','--size',type=int,default=32)
    parser.add_argument('-o','--output',help="output database")
    args = parser.parse_args(sys.argv[2:])
    print args.inputs
    d = args.size
    count = 0
    with tables.openFile(args.output,"w") as odb:
        odb.setNodeAttr("/","sizemode","perchar")
        patches = h5utils.create_earray(odb,"patches",(d,d),'f')
        classes = h5utils.create_earray(odb,"classes",(),'i')
        for fname in args.inputs:
            print "===",fname
            lineno = 0
            for c,cimage in tess_readchars(fname):
                cimage = improc.classifier_normalize(cimage,size=d)
                patches.append([cimage])
                classes.append([lig.ord(c)])
                count += 1
                lineno += 1
    sys.exit(0)
